{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b6945fd6",
   "metadata": {
    "papermill": {
     "duration": 0.006721,
     "end_time": "2023-11-10T10:54:57.574181",
     "exception": false,
     "start_time": "2023-11-10T10:54:57.567460",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "# GPT from Scratch using Lightning AI and Lance!\n",
    "\n",
    "This notebook follows the code I wrote for my talk at the Lightning AI meetup in London on 8th November.\n",
    "I implement a GPT model from scratch, including modules such as `CausalAttentionHead`, `MultiHeadedAttention` and the feed-forward network (FFN), then bind everything together in a Lightning module and train it with the Lightning trainer.\n",
    "\n",
    "### Notes on Data Tokenization and Lance\n",
    "I am using [Lance](https://github.com/lancedb/lance/) to load our training data. It is a modern columnar data format for ML and LLMs implemented in Rust.\n",
    "\n",
    "The problem is that, with limited memory and compute, we can't load the entire TinyStories dataset (about 2.2 GB in size) into memory and tokenize it in one go. The solution is to pre-tokenize the dataset, convert the tokens into a PyArrow table and save it in the Lance format.\n",
    "\n",
    "**Lance essentially allows us to only load some indices of the data at any given moment instead of loading the entire dataset and maxing out the memory**.\n",
    "\n",
    "If you want to play with Lance and check out its other use cases, see its repository: https://github.com/lancedb/lance/"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67f3feec-6f66-4581-93c1-674ba847b1ce",
   "metadata": {},
   "source": [
    "### Installing dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "fb30a04e",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-10T10:54:57.588247Z",
     "iopub.status.busy": "2023-11-10T10:54:57.587845Z",
     "iopub.status.idle": "2023-11-10T10:55:34.609038Z",
     "shell.execute_reply": "2023-11-10T10:55:34.608126Z"
    },
    "papermill": {
     "duration": 37.031058,
     "end_time": "2023-11-10T10:55:34.611457",
     "exception": false,
     "start_time": "2023-11-10T10:54:57.580399",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%sh\n",
    "pip install -q pyarrow\n",
    "pip install -q pylance\n",
    "pip install -q lightning"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db7d12f5-c501-4085-93db-6901b3805a3a",
   "metadata": {},
   "source": [
    "### Importing libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9017101c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-10T10:55:34.625879Z",
     "iopub.status.busy": "2023-11-10T10:55:34.625287Z",
     "iopub.status.idle": "2023-11-10T10:55:51.126992Z",
     "shell.execute_reply": "2023-11-10T10:55:51.126147Z"
    },
    "id": "83ba6687",
    "papermill": {
     "duration": 16.511123,
     "end_time": "2023-11-10T10:55:51.129304",
     "exception": false,
     "start_time": "2023-11-10T10:55:34.618181",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.23.5\n",
      "  warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "01155039c03042cc9a22b8073c5829f4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0dbe530af664490f8c2c716e69c170a2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b9a9083b84f5455784a23816953afb6a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d667e36ba38b4b3389b78acb944b819b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "import lightning\n",
    "\n",
    "import lance\n",
    "import pyarrow as pa\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "from transformers import GPT2TokenizerFast\n",
    "\n",
    "tokenizer = GPT2TokenizerFast.from_pretrained(\"gpt2\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22ab3fae",
   "metadata": {
    "id": "96cc4ab3",
    "papermill": {
     "duration": 0.007297,
     "end_time": "2023-11-10T10:55:51.143736",
     "exception": false,
     "start_time": "2023-11-10T10:55:51.136439",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Dataset Creation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca58e5d1",
   "metadata": {
    "papermill": {
     "duration": 0.006535,
     "end_time": "2023-11-10T10:55:51.157677",
     "exception": false,
     "start_time": "2023-11-10T10:55:51.151142",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "We load the TinyStories dataset, tokenize the first 1K rows of it and save the resulting tokens as a PyArrow table in a Lance dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "453c6712",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-10T10:55:51.173146Z",
     "iopub.status.busy": "2023-11-10T10:55:51.172437Z",
     "iopub.status.idle": "2023-11-10T10:57:23.558625Z",
     "shell.execute_reply": "2023-11-10T10:57:23.557751Z"
    },
    "id": "4c38da78",
    "papermill": {
     "duration": 92.396101,
     "end_time": "2023-11-10T10:57:23.560486",
     "exception": false,
     "start_time": "2023-11-10T10:55:51.164385",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading and preparing dataset text/roneneldan--TinyStories to /root/.cache/huggingface/datasets/text/roneneldan--TinyStories-e7877524f0320955/0.0.0/4b86d314f7236db91f0a0f5cda32d4375445e64c5eda2692655dd99c2dac68e8...\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f6d7c721944a4736872c3a5d4182032c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8b8486d1400b425cb69c7f693be4c438",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading data:   0%|          | 0.00/2.23G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "05835733885445769a15584bac14b6ec",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset text downloaded and prepared to /root/.cache/huggingface/datasets/text/roneneldan--TinyStories-e7877524f0320955/0.0.0/4b86d314f7236db91f0a0f5cda32d4375445e64c5eda2692655dd99c2dac68e8. Subsequent calls will reuse this data.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7ba7231bb273483faffb49927f6ea428",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\n",
    "    \"roneneldan/TinyStories\", data_files={\"train\": \"TinyStoriesV2-GPT4-train.txt\"}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "acc4287f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-10T10:57:23.579172Z",
     "iopub.status.busy": "2023-11-10T10:57:23.578079Z",
     "iopub.status.idle": "2023-11-10T10:57:25.302724Z",
     "shell.execute_reply": "2023-11-10T10:57:25.301547Z"
    },
    "id": "04b01749",
    "papermill": {
     "duration": 1.736169,
     "end_time": "2023-11-10T10:57:25.304706",
     "exception": false,
     "start_time": "2023-11-10T10:57:23.568537",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "36f758a2075e402ea0a87604c8dd367a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/1000 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total tokens in tokenized dataset: 31,603\n"
     ]
    }
   ],
   "source": [
    "# Join 'total_rows' number of sentences one after another\n",
    "all_tokens = []\n",
    "# Only choosing 1K sentences for now. Increase if you want to train it for longer on larger hardware\n",
    "total_rows = 1000\n",
    "data = dataset[\"train\"].select([x for x in range(total_rows)])\n",
    "for row in tqdm(data[\"text\"], total=len(data)):\n",
    "    row = row.replace(\"<|endoftext|>\", \" \")\n",
    "    encoded = tokenizer(row)[\"input_ids\"]\n",
    "    all_tokens.extend(encoded)\n",
    "\n",
    "pa_table = pa.Table.from_arrays([all_tokens], names=[\"value\"])\n",
    "lance.write_dataset(pa_table, \"tiny_stories_gpt4_encoded.lance\", mode=\"create\")\n",
    "\n",
    "print(f\"Total tokens in tokenized dataset: {len(all_tokens):,.0f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a60f566",
   "metadata": {
    "id": "c8e601b2",
    "papermill": {
     "duration": 0.007743,
     "end_time": "2023-11-10T10:57:25.320811",
     "exception": false,
     "start_time": "2023-11-10T10:57:25.313068",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Model and Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8784525d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-10T10:57:25.338347Z",
     "iopub.status.busy": "2023-11-10T10:57:25.337621Z",
     "iopub.status.idle": "2023-11-10T10:57:25.343385Z",
     "shell.execute_reply": "2023-11-10T10:57:25.342667Z"
    },
    "id": "e44ce1d6",
    "papermill": {
     "duration": 0.016455,
     "end_time": "2023-11-10T10:57:25.345226",
     "exception": false,
     "start_time": "2023-11-10T10:57:25.328771",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class Config:\n",
    "    vocab_size = 50304  # padded from GPT-2's 50257 up to the next multiple of 64, which makes matrix ops more efficient\n",
    "    n_epochs = 50\n",
    "    batch_size = 48\n",
    "    lr = 3e-4\n",
    "    wd = 1e-6\n",
    "    n_embed = 256\n",
    "    num_blocks = 12\n",
    "    num_heads = 12\n",
    "    head_size = n_embed // num_heads\n",
    "    context_len = 224\n",
    "    attn_dropout_val = 0.2\n",
    "    mha_dropout_val = 0.2\n",
    "    ffn_dropout_val = 0.2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fcdc0a38",
   "metadata": {
    "papermill": {
     "duration": 0.007707,
     "end_time": "2023-11-10T10:57:25.361161",
     "exception": false,
     "start_time": "2023-11-10T10:57:25.353454",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Attention - `CausalAttentionHead`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fe3a229f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-10T10:57:25.378863Z",
     "iopub.status.busy": "2023-11-10T10:57:25.378318Z",
     "iopub.status.idle": "2023-11-10T10:57:25.387393Z",
     "shell.execute_reply": "2023-11-10T10:57:25.386598Z"
    },
    "id": "3154cb9e",
    "papermill": {
     "duration": 0.020237,
     "end_time": "2023-11-10T10:57:25.389401",
     "exception": false,
     "start_time": "2023-11-10T10:57:25.369164",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class CausalAttentionHead(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super(CausalAttentionHead, self).__init__()\n",
    "        self.config = config\n",
    "\n",
    "        # QKV layers\n",
    "        self.query = nn.Linear(config.n_embed, config.head_size, bias=False)\n",
    "        self.key = nn.Linear(config.n_embed, config.head_size, bias=False)\n",
    "        self.value = nn.Linear(config.n_embed, config.head_size, bias=False)\n",
    "        self.attn_drop = nn.Dropout(config.attn_dropout_val)\n",
    "\n",
    "        # Mask for ensuring causality during training\n",
    "        self.register_buffer(\n",
    "            \"mask\", torch.tril(torch.ones(config.context_len, config.context_len))\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Shape of x: [bs, context_len, embed_dim]\n",
    "        bs, context_len, embed_dim = x.shape\n",
    "        q, k, v = self.query(x), self.key(x), self.value(x)\n",
    "\n",
    "        # Get the attention weights (scaled dot-product: divide by sqrt(head_size))\n",
    "        attn_filter = torch.divide(\n",
    "            torch.bmm(q, k.transpose(1, 2)), self.config.head_size**0.5\n",
    "        )\n",
    "        attn_filter = attn_filter.masked_fill(\n",
    "            self.mask[:context_len, :context_len] == 0, float(\"-inf\")\n",
    "        )\n",
    "        attn_weights = F.softmax(attn_filter, dim=-1)\n",
    "        attn_weights = self.attn_drop(attn_weights)\n",
    "\n",
    "        # Now we do weighted aggregation of values to get the output of attention\n",
    "        # attn_weights [bs, c, c] x V [bs, c, h] = output [bs, c, head_size]\n",
    "        output = torch.bmm(attn_weights, v)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd6bd25c-9219-44c0-a383-cf46e8fd16f7",
   "metadata": {},
   "source": [
    "### `MultiHeadedAttention`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c5b28713",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-10T10:57:25.406542Z",
     "iopub.status.busy": "2023-11-10T10:57:25.405918Z",
     "iopub.status.idle": "2023-11-10T10:57:25.412674Z",
     "shell.execute_reply": "2023-11-10T10:57:25.411864Z"
    },
    "id": "10e4a347",
    "papermill": {
     "duration": 0.017262,
     "end_time": "2023-11-10T10:57:25.414512",
     "exception": false,
     "start_time": "2023-11-10T10:57:25.397250",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class MultiHeadedAttention(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super(MultiHeadedAttention, self).__init__()\n",
    "        self.config = config\n",
    "\n",
    "        # Turn all the AttentionHeads into a ModuleList\n",
    "        self.heads = nn.ModuleList(\n",
    "            [CausalAttentionHead(config) for _ in range(config.num_heads)]\n",
    "        )\n",
    "\n",
    "        # Projection (followed by dropout) that maps mha_output back to n_embed dim\n",
    "        self.proj = nn.Linear(config.num_heads * config.head_size, config.n_embed)\n",
    "        self.mha_drop = nn.Dropout(config.mha_dropout_val)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Concatenate all the attention head outputs together\n",
    "        mha_output = torch.cat([head(x) for head in self.heads], dim=-1)\n",
    "        return self.mha_drop(self.proj(mha_output))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d68e5d18",
   "metadata": {
    "papermill": {
     "duration": 0.007715,
     "end_time": "2023-11-10T10:57:25.430287",
     "exception": false,
     "start_time": "2023-11-10T10:57:25.422572",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## FeedForward Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5e8e765a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-10T10:57:25.447220Z",
     "iopub.status.busy": "2023-11-10T10:57:25.446731Z",
     "iopub.status.idle": "2023-11-10T10:57:25.452482Z",
     "shell.execute_reply": "2023-11-10T10:57:25.451684Z"
    },
    "id": "36d775d5",
    "papermill": {
     "duration": 0.016276,
     "end_time": "2023-11-10T10:57:25.454311",
     "exception": false,
     "start_time": "2023-11-10T10:57:25.438035",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class FeedForwardNet(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super(FeedForwardNet, self).__init__()\n",
    "\n",
    "        self.ffn = nn.Sequential(\n",
    "            nn.Linear(config.n_embed, config.n_embed * 4),\n",
    "            nn.GELU(),\n",
    "            nn.Linear(config.n_embed * 4, config.n_embed),\n",
    "            nn.Dropout(config.ffn_dropout_val),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.ffn(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b57b0e78",
   "metadata": {
    "papermill": {
     "duration": 0.007695,
     "end_time": "2023-11-10T10:57:25.469797",
     "exception": false,
     "start_time": "2023-11-10T10:57:25.462102",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## One Single Block of the GPT model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a4f2fbec",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-10T10:57:25.487729Z",
     "iopub.status.busy": "2023-11-10T10:57:25.487160Z",
     "iopub.status.idle": "2023-11-10T10:57:25.493311Z",
     "shell.execute_reply": "2023-11-10T10:57:25.492498Z"
    },
    "id": "79f13d36",
    "papermill": {
     "duration": 0.017686,
     "end_time": "2023-11-10T10:57:25.495117",
     "exception": false,
     "start_time": "2023-11-10T10:57:25.477431",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class Block(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        super(Block, self).__init__()\n",
    "\n",
    "        # Architecture of one block of GPT\n",
    "        self.mha = MultiHeadedAttention(config)\n",
    "        self.ln1 = nn.LayerNorm(config.n_embed)\n",
    "        self.ffn = FeedForwardNet(config)\n",
    "        self.ln2 = nn.LayerNorm(config.n_embed)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.ln1(x + self.mha(x))\n",
    "        x = self.ln2(x + self.ffn(x))\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6ca6fea",
   "metadata": {
    "papermill": {
     "duration": 0.007724,
     "end_time": "2023-11-10T10:57:25.510585",
     "exception": false,
     "start_time": "2023-11-10T10:57:25.502861",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Entire GPT model, end-to-end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b8adb5ab",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-10T10:57:25.527293Z",
     "iopub.status.busy": "2023-11-10T10:57:25.527011Z",
     "iopub.status.idle": "2023-11-10T10:57:25.542816Z",
     "shell.execute_reply": "2023-11-10T10:57:25.541396Z"
    },
    "id": "5c1378f1",
    "papermill": {
     "duration": 0.026406,
     "end_time": "2023-11-10T10:57:25.544682",
     "exception": false,
     "start_time": "2023-11-10T10:57:25.518276",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class GPT(lightning.LightningModule):\n",
    "    def __init__(self, config):\n",
    "        super(GPT, self).__init__()\n",
    "        self.config = config\n",
    "        self.save_hyperparameters()\n",
    "\n",
    "        # Define token and positional embeddings\n",
    "        self.token_embedding = nn.Embedding(config.vocab_size, config.n_embed)\n",
    "        self.positional_embedding = nn.Embedding(config.context_len, config.n_embed)\n",
    "\n",
    "        # Define the blocks\n",
    "        self.backbone = nn.Sequential(\n",
    "            *[Block(config) for _ in range(config.num_blocks)]\n",
    "        )\n",
    "\n",
    "        # Define the LM head\n",
    "        self.lm_head = nn.Linear(config.n_embed, config.vocab_size)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Apply token embeddings to the data: (B, C) -> (B, C, n_embed)\n",
    "        tok_emb = self.token_embedding(x)\n",
    "\n",
    "        # Get positional embeddings using torch.arange\n",
    "        pos_emb = self.positional_embedding(\n",
    "            torch.arange(x.shape[1], device=self.device)\n",
    "        )\n",
    "\n",
    "        # Add both embeddings\n",
    "        x = tok_emb + pos_emb\n",
    "\n",
    "        # Pass the input data through all blocks\n",
    "        x = self.backbone(x)\n",
    "\n",
    "        # Pass it through the lm head\n",
    "        logits = self.lm_head(x)\n",
    "        return logits\n",
    "\n",
    "    def get_loss(self, predictions, target):\n",
    "        B, C, V = predictions.shape\n",
    "        predictions = predictions.view(B * C, V)\n",
    "        target = target.view(B * C)\n",
    "        loss = F.cross_entropy(predictions, target)\n",
    "        return loss\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        text, target = batch\n",
    "        text = text.long()\n",
    "        target = target.long()\n",
    "        logits = self(text)\n",
    "        loss = self.get_loss(logits, target)\n",
    "\n",
    "        # Log the per-step loss and its running epoch average in the progress bar\n",
    "        self.log(\"loss\", loss, prog_bar=True, on_step=True, on_epoch=True)\n",
    "        return loss\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        opt = torch.optim.AdamW(\n",
    "            self.parameters(), lr=self.config.lr, weight_decay=self.config.wd\n",
    "        )\n",
    "        return [opt], []\n",
    "\n",
    "\n",
    "def generate(model, prompt, max_tokens, temperature=0.7):\n",
    "    \"\"\"\n",
    "    Generates text based on the provided prompt.\n",
    "    Model determinism can be changed with temperature\n",
    "    (must be > 0; higher means more random but more creative predictions)\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        for _ in range(max_tokens):\n",
    "            # Condition only on the most recent `context_len` tokens\n",
    "            context = prompt[:, -model.config.context_len :]\n",
    "            logits = model(context)\n",
    "            logits = logits[:, -1, :] / temperature\n",
    "            logit_probs = F.softmax(logits, dim=-1)\n",
    "            next_token = torch.multinomial(logit_probs, num_samples=1)\n",
    "            prompt = torch.cat((prompt, next_token), dim=1)\n",
    "    return prompt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf97a30d",
   "metadata": {
    "papermill": {
     "duration": 0.007576,
     "end_time": "2023-11-10T10:57:25.560370",
     "exception": false,
     "start_time": "2023-11-10T10:57:25.552794",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## `GPTDataset`\n",
    "Efficient and fast data loading, thanks to **Lance**!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "21290bb2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-10T10:57:25.577800Z",
     "iopub.status.busy": "2023-11-10T10:57:25.577265Z",
     "iopub.status.idle": "2023-11-10T10:57:25.584726Z",
     "shell.execute_reply": "2023-11-10T10:57:25.583825Z"
    },
    "id": "88bc8b84",
    "papermill": {
     "duration": 0.018316,
     "end_time": "2023-11-10T10:57:25.586487",
     "exception": false,
     "start_time": "2023-11-10T10:57:25.568171",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "class GPTDataset(Dataset):\n",
    "    def __init__(self, dataset_path, context_len):\n",
    "        # Load the lance dataset from the saved path\n",
    "        self.ds = lance.dataset(dataset_path)\n",
    "        self.context_len = context_len\n",
    "        # Doing this so the sampler never asks for an index at the end of text\n",
    "        self.length = self.ds.count_rows() - context_len\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.length\n",
    "\n",
    "    def from_idxs(self, idxs):\n",
    "        \"\"\"\n",
    "        Little Utility function to get the data from lance\n",
    "        \"\"\"\n",
    "        data = self.ds.take(idxs).to_pylist()\n",
    "        data = torch.tensor(list(map(lambda x: x[\"value\"], data)))\n",
    "        return data\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"\n",
    "        Generate a list of indices starting from the current idx to idx+context_len+1\n",
    "        Use the from_idxs function to get data in said indexes and then divide it into features (x) and target (y)\n",
    "        \"\"\"\n",
    "        current_window_idxs = np.arange(idx, idx + self.context_len + 1)\n",
    "        data = self.from_idxs(current_window_idxs)\n",
    "        x = data[0 : self.context_len]\n",
    "        # The target is the input shifted one token ahead\n",
    "        y = data[1 : self.context_len + 1]\n",
    "        return x, y"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11ff2a0b",
   "metadata": {
    "papermill": {
     "duration": 0.00754,
     "end_time": "2023-11-10T10:57:25.602101",
     "exception": false,
     "start_time": "2023-11-10T10:57:25.594561",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Finally, let's train the model!\n",
    "\n",
    "We'll train the model for 50 epochs, which took about 5 hours on a single GPU in this run. Change the number of epochs and other hyperparameters in the `Config` class if you are training for longer or on more powerful hardware."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "a5989fc0",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-10T10:57:25.618915Z",
     "iopub.status.busy": "2023-11-10T10:57:25.618656Z",
     "iopub.status.idle": "2023-11-10T15:53:09.002656Z",
     "shell.execute_reply": "2023-11-10T15:53:09.001693Z"
    },
    "id": "23ac470d",
    "papermill": {
     "duration": 17743.395025,
     "end_time": "2023-11-10T15:53:09.005041",
     "exception": false,
     "start_time": "2023-11-10T10:57:25.610016",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO: GPU available: True (cuda), used: True\n",
      "INFO: TPU available: False, using: 0 TPU cores\n",
      "INFO: IPU available: False, using: 0 IPUs\n",
      "INFO: HPU available: False, using: 0 HPUs\n",
      "WARNING: Missing logger folder: /kaggle/working/lightning_logs\n",
      "INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "INFO: \n",
      "  | Name                 | Type       | Params\n",
      "----------------------------------------------------\n",
      "0 | token_embedding      | Embedding  | 12.9 M\n",
      "1 | positional_embedding | Embedding  | 57.3 K\n",
      "2 | backbone             | Sequential | 9.4 M \n",
      "3 | lm_head              | Linear     | 12.9 M\n",
      "----------------------------------------------------\n",
      "35.3 M    Trainable params\n",
      "0         Non-trainable params\n",
      "35.3 M    Total params\n",
      "141.128   Total estimated model params size (MB)\n",
      "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4f8917be126e4849bf8f7403ab594d76",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO: `Trainer.fit` stopped: `max_epochs=50` reached.\n"
     ]
    }
   ],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    # Path of the encoded lance dataset\n",
    "    dataset_path = \"tiny_stories_gpt4_encoded.lance\"\n",
    "\n",
    "    # Init config\n",
    "    config = Config()\n",
    "\n",
    "    # Init model\n",
    "    gpt = GPT(config)\n",
    "\n",
    "    # Init the dataset\n",
    "    dataset = GPTDataset(dataset_path, config.context_len)\n",
    "    loader = DataLoader(\n",
    "        dataset,\n",
    "        batch_size=config.batch_size,\n",
    "        shuffle=False,\n",
    "    )\n",
    "\n",
    "    # Init the trainer\n",
    "    trainer = lightning.Trainer(accelerator=\"auto\", max_epochs=config.n_epochs)\n",
    "\n",
    "    # Fit on the data\n",
    "    trainer.fit(gpt, loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b83f8a9d",
   "metadata": {
    "papermill": {
     "duration": 0.008662,
     "end_time": "2023-11-10T15:53:09.022774",
     "exception": false,
     "start_time": "2023-11-10T15:53:09.014112",
     "status": "completed"
    },
    "tags": []
   },
   "source": [
    "## Generate some text!\n",
    "\n",
    "Let's see how much our model learnt about the text data by asking it to generate some text given a prompt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8514bb14",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-11-10T15:53:09.041052Z",
     "iopub.status.busy": "2023-11-10T15:53:09.040747Z",
     "iopub.status.idle": "2023-11-10T15:53:19.721245Z",
     "shell.execute_reply": "2023-11-10T15:53:19.720149Z"
    },
    "id": "e49d57f6",
    "papermill": {
     "duration": 10.692037,
     "end_time": "2023-11-10T15:53:19.723323",
     "exception": false,
     "start_time": "2023-11-10T15:53:09.031286",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "My cat is very special. She named her angel Lily.One day, Anna and Lily went to the park with her mom. They saw a big slide, a swing, and a sandbox. Anna wanted to play with everything. Anna has a swing, but they were playing in the slide, but they all lived happily in the slide. She asked her mom, \"Can I go on the slide, mom?\"\"Yes, Anna. She likes to play too,\" her mom said.Anna nodded and slid down. She laughed and said, \"Whee! That was fun, Lily!\"Anna was fun, \"Look, you can fly like a real angel!\"Then she went to the sandbox. She pushed Lily on top and said, \"Look, Lily, Lily. She is my angel!\"Anna was having a lot of fun. She put Lily on top and said, \"You are the castle, Lily!\"Anna was having a lot of fun. But she did not see the unknown boy who came to the sandbox. He was bigger than Anna and wanted to take her. He saw Lily from the castle and\n"
     ]
    }
   ],
   "source": [
    "# Generate some predictions\n",
    "prompt = \"My cat is\"  # Change the prompt to whatever you want\n",
    "\n",
    "# Use the GPU when one is available, otherwise fall back to the CPU\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "gpt = gpt.to(device)\n",
    "prompt = tokenizer.encode(prompt, return_tensors=\"pt\").to(device)\n",
    "generated_text = generate(gpt, prompt, max_tokens=config.context_len, temperature=0.7)\n",
    "generated_text = tokenizer.decode(generated_text.tolist()[0])\n",
    "print(generated_text)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.1"
  },
  "papermill": {
   "default_parameters": {},
   "duration": 17908.913335,
   "end_time": "2023-11-10T15:53:23.056230",
   "environment_variables": {},
   "exception": null,
   "input_path": "__notebook__.ipynb",
   "output_path": "__notebook__.ipynb",
   "parameters": {},
   "start_time": "2023-11-10T10:54:54.142895",
   "version": "2.4.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
